"""
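Phase 6: Robustness & Paper
============================
Robustness checks for the automation regressions:
1. Contemporary vs. 5-year lagged demographic regressors (Z_1-Z_3, old/youth dependency)
2. First-differenced regressions
3. OECD / non-OECD subsamples
4. Excluding financial centers (LUX, IRL, HKG, SGP, CHE, NLD, BEL)
Output (in output/tables/): lagged_automation.md, first_diff_automation.md,
robustness_subsamples.md (checks 3 and 4 share this last file)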
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output directory for the markdown tables.
DATA = PROJECT_DIR.joinpath("data", "processed")
OUT_TABLES = PROJECT_DIR.joinpath("output", "tables")
# Created at import time so every table writer can assume it exists.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ISO3 codes treated as OECD members for the OECD / non-OECD split
# (38 codes; NOTE(review): fixed list, not time-varying membership).
OECD = set("""
    AUS AUT BEL CAN CHL COL CRI CZE DNK EST
    FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN
    KOR LVA LTU LUX MEX NLD NZL NOR POL PRT
    SVK SVN ESP SWE CHE TUR GBR USA
""".split())

# ISO3 codes dropped in the "excluding financial centers" subsample.
FINANCIAL_CENTERS = set("LUX IRL HKG SGP CHE NLD BEL".split())


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one coefficient for a results table.

    Returns a ``(coef_text, se_text)`` pair: the coefficient to four decimals
    with its significance stars appended, and the standard error to four
    decimals wrapped in parentheses.
    """
    coef_text = "{:.4f}{}".format(val, stars(p))
    se_text = "({:.4f})".format(se)
    return coef_text, se_text


def run_panel_gls(df, y_var, x_vars, label, min_obs=50):
    """Run PanelGLS of ``y_var`` on ``x_vars`` and return a results dict.

    Rows with a missing value in any regression column (including ``iso3``
    and ``year``) are dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel expected to contain ``y_var``, every name in ``x_vars``,
        ``iso3`` and ``year``.
    y_var : str
        Dependent-variable column.
    x_vars : iterable of str
        Regressor columns, in the order their coefficients are reported.
    label : str
        Model label used in console output and stored under ``'model'``.
    min_obs : int, default 50
        The model is skipped when fewer complete rows than this remain.

    Returns
    -------
    dict or None
        ``model``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho`` plus
        ``<var>_coef`` / ``<var>_se`` / ``<var>_p`` for each regressor, or
        ``None`` when a column is missing, there are too few observations,
        or the fit raises.
    """
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['iso3', 'year']

    # An optional variable (DV or control) may be absent from this panel;
    # skip the model instead of aborting the whole run with a KeyError,
    # consistent with the other "skipping" paths below.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns {missing}, skipping")
        return None

    sub = df[cols].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    y = sub[y_var].values
    X = sub[x_vars].values
    try:
        gls.fit(y, X, sub['iso3'].values, sub['year'].values)
    except Exception as e:  # boundary: report and skip only this model
        print(f"  {label}: GLS failed ({e}), skipping")
        return None

    result = {
        'model': label,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    # NOTE(review): assumes PanelGLS reports beta/se/pvalues in the column
    # order of X (no leading intercept), as this positional indexing requires.
    for i, name in enumerate(x_vars):
        sig = stars(gls.pvalues[i])
        print(f"    {name:30s} {gls.beta[i]:8.4f} ({gls.se[i]:.4f}) {sig}")
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    return result


def write_table(results, filename, title):
    """Write regression results to ``OUT_TABLES / filename`` as a markdown table.

    One column per model dict in ``results`` (as returned by
    ``run_panel_gls``). Each regressor gets a coefficient row (with
    significance stars) and a standard-error row, followed by N, R² and
    country counts and a notes footer. When ``results`` is empty nothing is
    written, so any file from a previous run is left in place.
    """
    if not results:
        return

    suffix = '_coef'
    # Regressor names in first-seen order across all models. Strip only the
    # trailing suffix: str.replace would also remove '_coef' inside a name
    # (e.g. 'gini_coefficient'), breaking the lookups below.
    all_vars = []
    for r in results:
        for key in r:
            if key.endswith(suffix):
                vname = key[:-len(suffix)]
                if vname not in all_vars:
                    all_vars.append(vname)

    model_labels = [r['model'] for r in results]
    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(model_labels) + " |")
    lines.append("|:---|" + "|".join("---:" for _ in results) + "|")

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                # Model does not include this regressor.
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # A second delimiter row is not valid inside a GFM table body (it renders
    # as literal ':---' cells), so separate the summary rows with a blank row.
    lines.append("| |" + " |" * len(results))
    lines.append("| N |" + "".join(f" {r['n_obs']} |" for r in results))
    lines.append("| R² |" + "".join(f" {r['r_squared']:.4f} |" for r in results))
    lines.append("| Countries |" + "".join(f" {r['n_countries']} |" for r in results))

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # Explicit UTF-8: the table contains non-ASCII text ("R²") and the
    # platform default encoding is not guaranteed to handle it.
    path.write_text('\n'.join(lines) + '\n', encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── 1. 5-Year Lagged Demographics ──────────────────────────────────

def lagged_automation(df):
    """Compare contemporaneous and 5-year-lagged demographic regressors.

    Dating the demographic regressors at t-5 addresses simultaneity concerns.
    For each automation dependent variable, four models are estimated with the
    same controls: Z_t, Z_t-5, Age_t (old/youth dependency) and Age_t-5.
    Results are written to ``lagged_automation.md``.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` sorted by (iso3, year) with ``<var>_lag5`` columns
        added for each lagged variable present in the panel.
    """
    print("\n" + "=" * 60)
    print("1. LAGGED DEMOGRAPHICS (5-YEAR LAG)")
    print("=" * 60)

    df = df.sort_values(['iso3', 'year']).copy()

    # Construct 5-year lags by matching calendar years (the value at year-5)
    # rather than shifting five rows: with groupby().shift(5) any missing year
    # in a country's series would pair an observation with the wrong year.
    # For a gap-free panel both approaches give identical results.
    lag_vars = [v for v in ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep']
                if v in df.columns]
    lag_cols = {v: f'{v}_lag5' for v in lag_vars}
    lagged = (df[['iso3', 'year'] + lag_vars]
              .dropna(subset=['year'])
              # One source row per (iso3, year) so the left merge below can
              # never add rows; assumes (iso3, year) is the panel key.
              .drop_duplicates(subset=['iso3', 'year'])
              .rename(columns=lag_cols))
    lagged['year'] = lagged['year'] + 5
    # A left merge keeps every row of df in its sorted order.
    df = (df.drop(columns=list(lag_cols.values()), errors='ignore')
            .merge(lagged, on=['iso3', 'year'], how='left'))

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']

    # Automation DVs with enough coverage; fall back to investment otherwise.
    auto_dvs = [var for var in ['capital_intensity', 'labor_productivity',
                                'capital_per_worker']
                if var in df.columns and df[var].notna().sum() > 100]
    if not auto_dvs:
        auto_dvs = ['gross_investment_gdp']

    # (regressors, label suffix): contemporary vs. lagged, Z and age structure.
    specs = [
        (['Z_1', 'Z_2', 'Z_3'], 'Z_t'),
        (['Z_1_lag5', 'Z_2_lag5', 'Z_3_lag5'], 'Z_t-5'),
        (['old_dep', 'youth_dep'], 'Age_t'),
        (['old_dep_lag5', 'youth_dep_lag5'], 'Age_t-5'),
    ]

    results = []
    for dv in auto_dvs:
        print(f"\n--- Dependent variable: {dv} ---")
        for regressors, tag in specs:
            r = run_panel_gls(df, dv, regressors + controls, f'{dv}: {tag}')
            if r:
                results.append(r)

    write_table(results, "lagged_automation.md",
                "Contemporary vs. 5-Year Lagged Demographics")

    return df


# ── 2. First Differences ───────────────────────────────────────────

def first_diff_automation(df):
    """First-differenced regressions to test level vs. change effects.

    Adds ``d_<var>`` year-on-year changes for the demographic, control,
    trade and automation variables present in ``df`` (on a local copy), then
    regresses d(automation) on dZ and on d(age structure), plus
    d(trade_openness) on dZ when available. Results are written to
    ``first_diff_automation.md``.
    """
    print("\n" + "=" * 60)
    print("2. FIRST DIFFERENCES")
    print("=" * 60)

    df = df.sort_values(['iso3', 'year']).copy()

    # Variables to difference
    diff_vars = ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                 'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen',
                 'trade_openness']

    # Add automation-specific vars if available
    diff_vars += [var for var in ['capital_intensity', 'labor_productivity',
                                  'capital_per_worker', 'gross_investment_gdp']
                  if var in df.columns]

    # diff() subtracts the previous *row* within each country. Keep the
    # difference only when that row is the previous calendar year; otherwise
    # a gap in the series would be treated as a one-year change.
    consecutive = df.groupby('iso3')['year'].diff() == 1
    for var in diff_vars:
        if var in df.columns:
            df[f'd_{var}'] = df.groupby('iso3')[var].diff().where(consecutive)

    controls_d = ['d_fiscal_bal_gdp', 'd_nfa_gdp_lag', 'd_rgdp_growth', 'd_kaopen']
    controls_d = [c for c in controls_d if c in df.columns]

    # Determine automation DVs (differenced)
    auto_dvs_d = [f'd_{var}' for var in ['capital_intensity', 'labor_productivity',
                                         'capital_per_worker']
                  if f'd_{var}' in df.columns and df[f'd_{var}'].notna().sum() > 100]
    if not auto_dvs_d and 'd_gross_investment_gdp' in df.columns:
        auto_dvs_d = ['d_gross_investment_gdp']

    results = []

    for dv in auto_dvs_d:
        print(f"\n--- Dependent variable: {dv} ---")

        # dZ → d(automation)
        r = run_panel_gls(df, dv,
                          ['d_Z_1', 'd_Z_2', 'd_Z_3'] + controls_d,
                          f'{dv}: dZ')
        if r:
            results.append(r)

        # d(age) → d(automation)
        r = run_panel_gls(df, dv,
                          ['d_old_dep', 'd_youth_dep'] + controls_d,
                          f'{dv}: dAge')
        if r:
            results.append(r)

    # Also: dZ → d(trade_openness)
    if 'd_trade_openness' in df.columns:
        r = run_panel_gls(df, 'd_trade_openness',
                          ['d_Z_1', 'd_Z_2', 'd_Z_3'] + controls_d,
                          'dTrade: dZ')
        if r:
            results.append(r)

    write_table(results, "first_diff_automation.md",
                "First-Differenced Regressions")


# ── 3. Robustness Subsamples ───────────────────────────────────────

def robustness_subsamples(df):
    """Re-estimate the baseline Z regressions on country subsamples.

    Subsamples are the full panel, OECD members, non-OECD countries and the
    panel excluding ``FINANCIAL_CENTERS``. Each automation DV (and
    ``trade_openness`` when present) is regressed on Z_1-Z_3 plus controls
    in every subsample; results go to ``robustness_subsamples.md``.
    """
    print("\n" + "=" * 60)
    print("3. ROBUSTNESS SUBSAMPLES")
    print("=" * 60)

    df = df.copy()
    df['is_oecd'] = df['iso3'].isin(OECD).astype(int)
    df['is_fc'] = df['iso3'].isin(FINANCIAL_CENTERS).astype(int)

    oecd_df = df[df['is_oecd'] == 1].copy()
    non_oecd_df = df[df['is_oecd'] == 0].copy()
    no_fc_df = df[df['is_fc'] == 0].copy()

    print(f"  Full: {df['iso3'].nunique()} countries, {len(df)} obs")
    print(f"  OECD: {oecd_df['iso3'].nunique()} countries, {len(oecd_df)} obs")
    print(f"  Non-OECD: {non_oecd_df['iso3'].nunique()} countries, {len(non_oecd_df)} obs")
    print(f"  Excl. Fin Centers: {no_fc_df['iso3'].nunique()} countries, {len(no_fc_df)} obs")

    # Same subsample order for every dependent variable.
    subsamples = [(df, 'Full'),
                  (oecd_df, 'OECD'),
                  (non_oecd_df, 'Non-OECD'),
                  (no_fc_df, 'Excl. FC')]

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']
    x_vars = ['Z_1', 'Z_2', 'Z_3'] + controls

    # Automation DVs with enough coverage; fall back to investment otherwise.
    auto_dvs = [var for var in ['capital_intensity', 'labor_productivity',
                                'capital_per_worker']
                if var in df.columns and df[var].notna().sum() > 100]
    if not auto_dvs:
        auto_dvs = ['gross_investment_gdp']

    results = []

    for dv in auto_dvs:
        print(f"\n--- Dependent variable: {dv} ---")
        for sub_df, label in subsamples:
            r = run_panel_gls(sub_df, dv, x_vars, f'{dv}: {label}')
            if r:
                results.append(r)

    # Also trade_openness as DV -- only when the panel has it (the
    # first-difference step guards this column the same way); otherwise
    # df[...] would raise KeyError and abort the run.
    if 'trade_openness' in df.columns:
        print("\n--- Dependent variable: trade_openness ---")
        for sub_df, label in subsamples:
            r = run_panel_gls(sub_df, 'trade_openness', x_vars, f'trade: {label}')
            if r:
                results.append(r)
    else:
        print("\n  trade_openness not in panel, skipping trade regressions")

    write_table(results, "robustness_subsamples.md",
                "Robustness: OECD/Non-OECD and Excluding Financial Centers")


# ── Main ────────────────────────────────────────────────────────────

def main():
    """Load the automation panel and run the Phase 6 robustness checks in order."""
    banner = "=" * 70
    print(banner)
    print("PHASE 6: ROBUSTNESS & PAPER")
    print(banner)

    panel = pd.read_csv(DATA / "automation_panel.csv")
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries")
    print(f"Columns: {list(panel.columns)}")

    # Step 1 returns the panel with the 5-year lag columns added; steps 2
    # and 3 receive that augmented panel.
    panel = lagged_automation(panel)
    first_diff_automation(panel)
    robustness_subsamples(panel)

    print("\n" + banner)
    print("PHASE 6 COMPLETE")
    print(banner)


if __name__ == '__main__':
    main()
